import os.path
import sys, importlib

import argparse

import numpy as np

import torch
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
import torch.utils.data as data_utils

from util import *

sys.path.append('')
from mnist_loader import MnistRotated
import datetime

from feddirt import Central




def build_arg_parser():
    """Build the command-line parser for the FedDIRT training script.

    Returns:
        argparse.ArgumentParser: parser declaring every option this script
        reads from the command line.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='FedDIRT')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed (default: 0)')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--iters', type=int, default=1500,
                        help='number of training iterations (default: 1500)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--num-supervised', default=1000, type=int,
                        help="number of supervised examples, /10 = samples per class")

    # Basic setting
    # nargs='+' collects whole domain names; the previous type=list split a
    # single argument string into individual characters.
    # NOTE: the main script currently replaces this value with a fixed list.
    parser.add_argument('--list_train_domains', type=str, nargs='+',
                        default=['0', '15', '30', '45', '60', '75'],
                        help='domains used during training')
    parser.add_argument('--target_domain', type=str, default='75',
                        help='domain used during testing')
    parser.add_argument('--model', type=str, default='dirt')
    parser.add_argument('--dataset', type=str, default='RotatedMnist')

    # StarGAN Model
    parser.add_argument('--d-dim', type=int, default=5,
                        help='number of domains')
    parser.add_argument('--x-dim', type=int, default=784,
                        help='input size after flattening')
    parser.add_argument('--y-dim', type=int, default=10,
                        help='number of classes')
    parser.add_argument('--zd-dim', type=int, default=64,
                        help='size of latent space 1')
    parser.add_argument('--zx-dim', type=int, default=64,
                        help='size of latent space 2')
    parser.add_argument('--zy-dim', type=int, default=64,
                        help='size of latent space 3')

    # Aux multipliers
    parser.add_argument('--aux_loss_multiplier_y', type=float, default=3500.,
                        help='multiplier for y classifier')
    parser.add_argument('--aux_loss_multiplier_d', type=float, default=2000.,
                        help='multiplier for d classifier')
    # Beta VAE part
    parser.add_argument('--beta_d', type=float, default=1.,
                        help='multiplier for KL d')
    parser.add_argument('--beta_x', type=float, default=1.,
                        help='multiplier for KL x')
    parser.add_argument('--beta_y', type=float, default=1.,
                        help='multiplier for KL y')

    parser.add_argument('-w', '--warmup', type=int, default=100, metavar='N',
                        help='number of epochs for warm-up. Set to 0 to turn warmup off.')
    parser.add_argument('--max_beta', type=float, default=1., metavar='MB',
                        help='max beta for warm-up')
    parser.add_argument('--min_beta', type=float, default=0.0, metavar='MB',
                        help='min beta for warm-up')

    # INB
    parser.add_argument('--nlayer', type=int, default=10,
                        help='Number of INB layers')
    parser.add_argument('--k', type=int, default=10,
                        help='K')
    parser.add_argument('--max_swd_iters', type=int, default=100,
                        help='Jmax')
    parser.add_argument('--use_shared', action='store_true', default=False,
                        help='Use shared space of INB')
    # No type is given, so a value supplied on the command line stays a string.
    parser.add_argument('--hist_bins', default=None)

    # AE
    # Restricting the choices reports an unknown activation at parse time
    # instead of as a KeyError later; keep in sync with the activation lookup
    # table in the main script.
    parser.add_argument('--activation', default='sigmoid',
                        choices=['tanh', 'sigmoid'])
    parser.add_argument('--ae_dir', default='')

    # DIRT training
    parser.add_argument('--mnist_subset', type=str, default='0')
    parser.add_argument('--all-data', action='store_true', default=False,
                        help='whether to use all MNIST in the training')
    parser.add_argument('--sync_step', type=int, default=1)
    parser.add_argument('--eval_step', type=int, default=50)

    parser.add_argument('--extra', action='store_true', default=False)
    parser.add_argument('--reg', default=1, type=float)

    # log
    parser.add_argument('--data_dir', type=str, default='')
    parser.add_argument('--model_dir', type=str, default='')
    parser.add_argument('--trans', type=str, default='stargan')
    parser.add_argument('--tn', type=str, default='inb')
    parser.add_argument('--outpath', type=str, default='./saved/',
                        help='where to save')
    parser.add_argument('--note', default='')

    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    # CUDA is used only when it was not disabled on the command line and
    # torch reports it as available.
    if args.no_cuda:
        args.cuda = False
    else:
        args.cuda = torch.cuda.is_available()

    # Turn the integer --device index into a torch.device (stored back on
    # args) and choose the extra DataLoader keyword arguments to match.
    if args.cuda:
        device = torch.device(f"cuda:{args.device}")
        kwargs = {'num_workers': 1, 'pin_memory': False}
    else:
        device = torch.device("cpu")
        kwargs = {}
    args.device = device

    # Seed torch and numpy from --seed, and turn off cuDNN benchmark mode.
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    # =================================================================================== #
    #                                    Saving and logging                               #
    # =================================================================================== #
    # if args.hist_bins is not None:
    #     if args.use_shared:
    #         run_name = f'c-hist{args.trans}-shared-{args.target_domain}{args.seed}'
    #     else:
    #         run_name = f'c-hist{args.trans}-{args.target_domain}{args.seed}'
    # else:
    #     if args.use_shared:
    #         run_name = f'c-{args.trans}-shared-{args.target_domain}{args.seed}'
    #     else:
    #         run_name = f'c-{args.trans}-{args.target_domain}{args.seed}'
    # if args.trans != 'stargan':
    #     run_name += f'-{args.nlayer}-{args.k}-{args.max_swd_iters}'

    # example: model = 'torchinb/10_10_200'
    # For any translator other than 'stargan', extend model_dir to
    # '<model_dir><trans>/75/inb.pt'.
    # NOTE(review): the '75' component is a literal, not args.target_domain —
    # confirm this is intended.
    if args.trans != 'stargan':
        args.model_dir = f'{args.model_dir}{args.trans}' + f'/{75}/inb.pt'

    # Assemble the run name piece by piece.
    # NOTE(review): args.note is included twice (once raw, once prefixed with
    # '_'), and run_name is not read again in this file — verify.
    name_parts = [args.trans, args.note]
    if args.use_shared:
        name_parts.append('shared')
    name_parts.append(f'_r{args.reg}')
    name_parts.append(f'_{args.note}')
    run_name = ''.join(name_parts)

    # Model name
    print(args.outpath)
    model_name = f'{args.outpath}{args.trans}_domain_{args.target_domain}_seed_{args.seed}'
    print(model_name)

    # # Log dir
    # log_dir = args.outpath + 'log/'
    # if not os.path.exists(log_dir):
    #     os.makedirs(log_dir)
    # #log_name = log_dir + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + '.txt'
    # if args.trans != 'stargan':
    #     log_name = log_dir + args.model + '_' + args.trans + '_' +\
    #                str(args.target_domain) + f'_subset{args.mnist_subset}' +\
    #                '_seed_' + str(args.seed)  +\
    #                f'_{args.nlayer}_{args.k}_{args.max_swd_iters}' + '.txt'
    #
    # else:
    #     log_name = log_dir + args.model + '_' + args.trans + '_' +\
    #                str(args.target_domain) + f'_subset{args.mnist_subset}' +\
    #                '_seed_' + str(args.seed)  + '.txt'
    # log_txt = open(log_name,'w+')


    # =================================================================================== #
    #                                     Prepare data                                    #
    # =================================================================================== #
    # Training domains are fixed here: the value parsed for
    # --list_train_domains is replaced by this list, and the target domain is
    # not filtered out of it.
    all_training_domains = ['0', '15', '30', '45', '60']
    args.list_train_domains = all_training_domains

    print(args.target_domain, args.list_train_domains)
    args.n_domains = len(args.list_train_domains)

    print(all_training_domains)

    # Keyword arguments shared by every DataLoader built below.
    loader_options = dict(batch_size=args.batch_size, shuffle=True, **kwargs)

    # One shuffled loader per training domain, keyed by the domain string.
    train_loader_dict = {}
    for domain_index, domain in enumerate(all_training_domains):
        train_set = MnistRotated([domain], [args.target_domain], args.data_dir,
                                 train=True, mnist_subset=args.mnist_subset,
                                 all_data=args.all_data)
        # Overwrite the dataset's domain labels with this domain's position
        # in the training list.
        train_set.train_domain = torch.ones_like(train_set.train_domain) * domain_index
        train_loader_dict[domain] = data_utils.DataLoader(train_set, **loader_options)

    # Test split, built with the full training-domain list and the target domain.
    test_set = MnistRotated(args.list_train_domains, [args.target_domain], args.data_dir,
                            train=False, mnist_subset=args.mnist_subset,
                            all_data=args.all_data)
    test_loader = data_utils.DataLoader(test_set, **loader_options)

    # Dataset sizes; train_set here is the dataset of the last training domain.
    print(len(train_set))
    print(len(test_set))

    # =================================================================================== #
    #                                     Prepare Model                                   #
    # =================================================================================== #

    # Replace the --activation name with the corresponding module instance;
    # a name missing from this table raises KeyError.
    activations = {'tanh': nn.Tanh(),
                   'sigmoid': nn.Sigmoid()}
    args.activation = activations[args.activation]

    # '/indae' is appended to ae_dir unconditionally, whatever --trans is.
    args.ae_dir += '/indae'

    # File name for the saved tracker: '/' in --trans is replaced so the name
    # stays a single path component.
    save_name = args.trans.replace('/', '-') + args.note
    if args.use_shared:
        save_name += 'share'

    # Save under --outpath (default './saved/', the location that used to be
    # hard-coded). The directory is created before training so a missing or
    # unwritable output location is reported up front rather than after the
    # whole run has finished.
    out_dir = args.outpath or '.'
    os.makedirs(out_dir, exist_ok=True)

    # setup the DIRT
    central = Central(train_loader_dict, test_loader, args)

    tracker = central.train()
    torch.save(tracker, os.path.join(out_dir, f'{save_name}.pt'))

